[Kernel] Add FlashInfer MoE A2A Kernel#36022
Conversation
Code Review
This pull request introduces the FlashInfer MoE A2A kernel, which is a welcome addition for improving performance in large batch size scenarios. The integration of the new kernel is well-executed across the codebase, including configuration, communicator management, and kernel selection logic. I've identified one high-severity issue related to determining the number of GPUs per node, which could lead to suboptimal performance. My detailed feedback and a suggested fix are in the review comment.
Force-pushed from 7c6aef4 to 0b13478.
Signed-off-by: Leo Tian <lctian@nvidia.com>
Hi @leo-cf-tian, the pre-commit checks have failed. Please run `uv pip install pre-commit`, `pre-commit install`, and `pre-commit run --all-files`. Then, commit the changes and push to your branch.
This pull request has merge conflicts that must be resolved before it can be merged.

@leo-cf-tian re-running with your latest commit and without my
tlrmchlsmth
left a comment
This looks good to me, assuming we see correctness and are past the issue @elvircrn was running into.
The trtllm scales issue appears for: and switching to made it go away. Can confirm the int32/int64 index issue went away in both cases.

Thanks @elvircrn. I don't expect many people to set those variables so high, but it could be nice to add a warning in case.
tlrmchlsmth
left a comment
I'd like to get this into v0.18.0, which cuts tomorrow. Could you please fix the pre-commit issues? It looks like they are caused by divergence from main.

I can help take a look tonight.
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
@tylertitsworth I fixed the merge conflicts. Can you start CI for this PR?
@wzhao18 could you hook up this kernel to CI?
It needs to be added to `.buildkite/test_areas/kernels.yaml`.
Sorry, I thought I posted the following response, but for some reason it was not submitted.

@tlrmchlsmth I re-examined the test and concluded that it may not be too meaningful to add here. It checks the result of `_supports_parallel_config` against an expectation that is derived from the function itself, which seems redundant, so I removed the test from the PR.

I think `test_modular_kernel_combinations_multigpu` should be a unified test that ensures both that (1) `_supports_parallel_config` is set correctly and (2) the combination actually works (through testing). However, as far as I can tell, this test is not in the CI pipeline, and I am having problems running it even on the current main branch. I will look into this in more detail and potentially improve it in a future PR.
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Squashed from vllm-project#36022. Signed-off-by: Elvir Crncevic <elvircrn@gmail.com> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Signed-off-by: Leo Tian <lctian@nvidia.com> Co-authored-by: wzhao18 <wzhao18.sz@gmail.com> Co-authored-by: Stefano Castagnetta <scastagnetta@nvidia.com> Co-authored-by: root <root@lyris0267.lyris.clusters.nvidia.com>
…ted hidden state scale shape" for EP32+ configs (#2853)

## 📌 Description

Fix `int32` overflow in `trtllm_fp4_block_scale_moe` that causes a misleading `NotImplementedError: Unsupported hidden state scale shape` when deploying large Expert Parallel configurations (e.g., EP32 with DeepSeek-R1 NVFP4).

**Step 1: NVFP4 activation quantization (per EP rank)**

Each of the 32 EP ranks quantizes its local activations via `vllm.ops.scaled_fp4_quant` with `is_sf_swizzled_layout=False`. From [nvfp4_quant_entry.cu](https://github.com/vllm-project/vllm/blob/a5e9d511defe2d2dc2dd270674fc197542fc0169/csrc/quantization/fp4/nvfp4_quant_entry.cu):

```cpp
output_sf = torch::empty(
    {m, n / CVT_FP4_SF_VEC_SIZE},
    torch::TensorOptions().device(device).dtype(torch::kUInt8));
```

For m=10240 (`max_num_batched_tokens`) and n=7168 (`hidden_size`):

- `hidden_states`: `[10240, 3584]` `uint8` (FP4 packed, 2 values per byte)
- `hidden_states_scale`: `[10240, 448]` `uint8`, viewed as `float8_e4m3fn`

No padding is applied in the non-swizzled layout. Scale numel = `10240 × 448 = 4,587,520`.
**Step 2: EP allgather via `dispatch()`**

`MoEPrepareAndFinalizeNaiveDPEPModular.prepare()` in [naive_dp_ep.py](https://github.com/vllm-project/vllm/blob/a5e9d511defe2d2dc2dd270674fc197542fc0169/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py) calls `get_ep_group().dispatch()`, which allgathers both `hidden_states` and `hidden_states_scale` (passed as `extra_tensors`) across all 32 EP ranks:

- `hidden_states`: `32 × [10240, 3584]` → `[327680, 3584]`
- `hidden_states_scale`: `32 × [10240, 448]` → `[327680, 448]`

**Step 3: Scale reshape in the vLLM wrapper**

In [trtllm_nvfp4_moe.py](https://github.com/vllm-project/vllm/blob/a5e9d511defe2d2dc2dd270674fc197542fc0169/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py), the scale is reshaped before being passed to FlashInfer:

```python
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
    *hidden_states.shape[:-1], -1)  # → [327680, 448]
```

At this point `hidden_states_scale.numel()` = 327680 × 448 = 146,800,640.
**Step 4: int32 overflow in the FlashInfer C++ kernel**

In `csrc/trtllm_fused_moe_kernel_launcher.cu`, the scale vector size is computed as:

```cpp
int const num_tokens = hidden_states.size(0);  // int (32-bit) = 327680
int hidden_size = hidden_states.size(1);       // int (32-bit) = 3584
if (hidden_states.dtype() == dl_uint8) hidden_size *= 2;  // hidden_size = 7168
hidden_states_scale_vec_size =
    (num_tokens * hidden_size) / hidden_states_scale.value().numel();
//   ^^^^^^^^^^^^^^^^^^^^^^^^
//   int * int = int → OVERFLOW before promotion to int64 for division
```

The overflow:

- `327680 × 7168 = 2,348,810,240`, while `INT_MAX = 2,147,483,647`
- `2,348,810,240 > INT_MAX`: signed int32 overflow (undefined behavior in C++; wraps to `-1,946,157,056` on two's-complement architectures)
- `vec_size = -1,946,157,056 / 146,800,640 = -13`
- `-13 ≠ 16` and `-13 ≠ 32`, so the kernel throws "Unsupported hidden state scale shape"

**Step 5: the overflow threshold**

For DeepSeek-R1 (hidden_size=7168):

- Max safe tokens: `INT_MAX / 7168 = 299,593`
- EP32 per-rank limit: `299,593 / 32 ≈ 9,362`
- Any `max_num_batched_tokens > 9362` with EP32 will trigger the overflow

We confirmed the overflow boundary on an 8-node GB200 cluster (32 GPUs, EP32, DP32) with `--all2all-backend allgather_reducescatter`:

| max_num_batched_tokens | Total tokens (×32) | M × 7168 | vs INT_MAX | Result |
| :--- | :--- | :--- | :--- | :--- |
| 9360 | 299,520 | 2,146,959,360 | < 2,147,483,647 | ✅ Success |
| 9370 | 299,840 | 2,149,253,120 | > 2,147,483,647 | ❌ **Crash** |
| 8192 (workaround) | 262,144 | 1,879,048,192 | < 2,147,483,647 | ✅ Success |
| 10240 (original) | 327,680 | 2,348,810,240 | > 2,147,483,647 | ❌ **Crash** |

**Reproduction**

vLLM serve with EP32:

```
vllm serve nvidia/DeepSeek-R1-NVFP4 \
  --tensor-parallel-size 1 \
  --data-parallel-size 32 \
  --enable-expert-parallel \
  --all2all-backend allgather_reducescatter \
  --max-num-batched-tokens 10240 \
  --kv-cache-dtype fp8 \
  --trust-remote-code
```

Crashes during engine initialization with:
`NotImplementedError: Unsupported hidden state scale shape.`

(Also found this issue in vllm-project/vllm#36022 (comment).)

**Fix**

Promote the multiplication operands to `int64_t` before the division to prevent the overflow:

- `hidden_states_scale_vec_size`: cast `num_tokens` to `int64_t` so the multiplication chain executes in 64-bit.
- `weight_scale_vec_size`: apply the same pattern with `local_num_experts` cast to `int64_t`, and declare the variable as `int64_t` for consistency.

```cpp
// In csrc/trtllm_fused_moe_kernel_launcher.cu
// Before (overflow-prone):
int const num_tokens = hidden_states.size(0);
int hidden_size = hidden_states.size(1);
if (hidden_states.dtype() == dl_uint8) hidden_size *= 2;
hidden_states_scale_vec_size =
    (num_tokens * hidden_size) / hidden_states_scale.value().numel();

// After (safe):
int const num_tokens = hidden_states.size(0);
int hidden_size = hidden_states.size(1);
if (hidden_states.dtype() == dl_uint8) hidden_size *= 2;
hidden_states_scale_vec_size =
    (static_cast<int64_t>(num_tokens) * hidden_size) /
    hidden_states_scale.value().numel();
```

The same pattern should also be applied to `weight_scale_vec_size` for safety:

```cpp
int64_t weight_scale_vec_size =
    (static_cast<int64_t>(local_num_experts) * intermediate_size *
     intermediate_size_factor * hidden_size) /
    gemm1_weights_scale.numel();
```

**Impact**

- Zero performance impact: these are CPU-side setup computations executed once before GPU kernel launch.
- Zero API change: no function signatures are modified.
- Unblocks EP32+ deployments for large-hidden-size models (DeepSeek-R1, etc.) with `max_num_batched_tokens` above the int32 threshold.

**Environment**

- Model: DeepSeek-R1-0528-FP4 (NVFP4, hidden_size=7168)
- Hardware: 8× GB200 nodes (32 GPUs), disaggregated prefill-decode
- Configuration: DP=32, EP=32, TP=1, PP=1
- vLLM: v0.17.2rc1 (bundled FlashInfer)

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer!
Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Bug Fixes**
  * Fixed integer overflow in internal size calculations that could cause crashes or incorrect behavior with very large models or batch sizes, improving stability and reliability for large-scale inference.

Co-authored-by: Albert Cheng (Engrg-Hardware 1) <albecheng@login-lyris01.lyris.clusters.nvidia.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Signed-off-by: Leo Tian <lctian@nvidia.com> Co-authored-by: wzhao18 <wzhao18.sz@gmail.com> Co-authored-by: Stefano Castagnetta <scastagnetta@nvidia.com> Co-authored-by: root <root@lyris0267.lyris.clusters.nvidia.com>
Purpose
This PR is a port of PR #32217 to the vLLM top-of-tree after the modular kernel refactors in #32564. It adds the latest TRT-LLM gen A2A kernel from FlashInfer's MoE-A2A API (one-sided all-to-all), as added in flashinfer-ai/flashinfer#2102. This should perform better than the older A2A kernel from #21003 (formerly flashinfer_all2allv) at large batch sizes.
The new kernel can be enabled by specifying `--all2all-backend flashinfer_nvlink_one_sided`. It is only available for nvfp4.

This PR also renames `flashinfer_all2allv` to `flashinfer_nvlink_two_sided`, as per suggestion, since it is more descriptive and matches the new implementation.

We conducted benchmarks and found a noticeable increase in throughput at high concurrency, up to a 14% increase in throughput at a concurrency of 512.
Testing
The PR also adds test coverage from @stecasta.
- `FlashInferMoeA2APrepareAndFinalize` in the modular kernel combinatorial test framework (`mk_objects.py`), enabling automatic multi-GPU testing against all compatible Expert backends with nvfp4 quantization
- `TrtLlmNvFp4ExpertsModular` in the same framework (previously missing from the test registry)
- `_supports_parallel_config` incompatibility matrix for the new `flashinfer_moe_a2a` backend across 7 Expert types
- `flashinfer_moe_a2a` and `flashinfer_all2allv` share the same incompatibility matrix, catching drift if one is updated without the other

Test plan

- `test_supports_parallel_config_flashinfer_moe_a2a` — CPU only, 7 parametrized cases
- `test_supports_parallel_config_parity_with_all2allv` — CPU only, 7 parametrized cases
- `test_modular_kernel_combinations_multigpu` — multi-GPU, auto-generated from `mk_objects.py` registrations

Notes
The incompatibility matrix tests do not require a GPU and can run in any CI environment. The combinatorial multi-GPU tests require 2x Blackwell GPUs with FlashInfer trtllm_moe_alltoall support.
Reproduction
To reproduce our results, the server can be launched with the following configuration:
To verify correctness, you can run gsm8k: